{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hosting ONNX models with Amazon Elastic Inference\n",
    "\n",
    "Amazon Elastic Inference (EI) is a resource you can attach to your Amazon EC2 instances to accelerate your deep learning (DL) inference workloads. EI allows you to add inference acceleration to an Amazon SageMaker hosted endpoint or Jupyter notebook for a fraction of the cost of using a full GPU instance. It reduces the cost of running deep learning inference by up to 75%. \n",
    "\n",
    "Amazon EI provides support for Apache MXNet, TensorFlow and ONNX models. The [Open Neural Network Exchange](https://onnx.ai/) (ONNX) is an open standard format for deep learning models that enables interoperability between deep learning frameworks such as Apache MXNet, Caffe2, Microsoft Cognitive Toolkit (CNTK), PyTorch and more. This means that we can train a model in any of these frameworks, export the trained model in ONNX format, and then import it into MXNet for inference. For more information, see https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html\n",
    "\n",
    "In this example, we will use the ResNet-152v1 model from [Deep residual learning for image recognition](https://arxiv.org/abs/1512.03385). This model, alongside many others, can be found at the [ONNX Model Zoo](https://github.com/onnx/models).\n",
    "\n",
    "We will use the SageMaker Python SDK to host this ONNX model in SageMaker and send inference requests to it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "First, we'll get the IAM execution role from our notebook environment, so that SageMaker can access resources in our AWS account later in the example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "isConfigCell": true
   },
   "outputs": [],
   "source": [
    "from sagemaker import get_execution_role\n",
    "\n",
    "role = get_execution_role()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The hosting script\n",
    "\n",
    "We'll need to provide a hosting script that can run on the SageMaker platform. This script will be invoked by SageMaker when we perform inference.\n",
    "\n",
    "The script we're using here implements two functions:\n",
    "\n",
    "* `model_fn()` - the SageMaker model server uses this function to load the model\n",
    "* `transform_fn()` - this function takes the request input, runs it through the loaded model, and returns the output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "!pygmentize resnet152.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparing the model\n",
    "\n",
    "To create a SageMaker Endpoint, we'll first need to prepare the model to be used in SageMaker."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Downloading the model\n",
    "\n",
    "For this example, we will use a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models), where you can find a collection of pre-trained models to work with. Here, we will download the [ResNet-152v1 model](https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet152v1/resnet152v1.onnx) trained on the ImageNet dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet152v1/resnet152v1.onnx')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compressing the model data\n",
    "\n",
    "Now that we have the model data locally, we need to compress it and upload the tarball to S3 so that the SageMaker Python SDK can create a Model from it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tarfile\n",
    "\n",
    "from sagemaker.session import Session\n",
    "\n",
    "with tarfile.open('onnx_model.tar.gz', mode='w:gz') as archive:\n",
    "    archive.add('resnet152v1.onnx')\n",
    "\n",
    "model_data = Session().upload_data(path='onnx_model.tar.gz', key_prefix='model')"
   ]
  },
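  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check, the sketch below shows one way to confirm what a tarball contains before uploading it. It uses only the Python standard library. The helper name `list_archive_members` and the placeholder file `dummy.onnx` are hypothetical and exist purely for illustration; they are not part of this example's model or of the SageMaker SDK.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tarfile\n",
    "import tempfile\n",
    "\n",
    "def list_archive_members(archive_path):\n",
    "    # Return the member names stored in a gzip-compressed tarball.\n",
    "    with tarfile.open(archive_path, mode='r:gz') as archive:\n",
    "        return archive.getnames()\n",
    "\n",
    "# Build a throwaway archive around a placeholder file (hypothetical names).\n",
    "workdir = tempfile.mkdtemp()\n",
    "placeholder = os.path.join(workdir, 'dummy.onnx')\n",
    "with open(placeholder, 'wb') as f:\n",
    "    f.write(b'not a real model')\n",
    "\n",
    "archive_path = os.path.join(workdir, 'dummy_model.tar.gz')\n",
    "with tarfile.open(archive_path, mode='w:gz') as archive:\n",
    "    # arcname stores the file at the archive root instead of under workdir.\n",
    "    archive.add(placeholder, arcname='dummy.onnx')\n",
    "\n",
    "members = list_archive_members(archive_path)\n",
    "print(members)\n",
    "```\n",
    "\n",
    "Running the same helper on `onnx_model.tar.gz` should list `resnet152v1.onnx`, since that is the only file added to the archive above."
   ]
  },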
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating a SageMaker Python SDK Model instance\n",
    "\n",
    "With the model data uploaded to S3, we now have everything we need to instantiate a SageMaker Python SDK Model. We'll pass the following arguments to the constructor:\n",
    "\n",
    "* `model_data`: the S3 location of the model data\n",
    "* `entry_point`: the script for model hosting that we looked at above\n",
    "* `role`: the IAM role used\n",
    "* `framework_version`: the MXNet version in use, in this case '1.4.0'\n",
    "\n",
    "You can read more about creating an `MXNetModel` object in the [SageMaker Python SDK API docs](https://sagemaker.readthedocs.io/en/latest/sagemaker.mxnet.html#mxnet-model)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sagemaker.mxnet import MXNetModel\n",
    "\n",
    "mxnet_model = MXNetModel(model_data=model_data,\n",
    "                         entry_point='resnet152.py',\n",
    "                         role=role,\n",
    "                         framework_version='1.4.0')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating an inference endpoint and attaching an Elastic Inference (EI) accelerator\n",
    "\n",
    "Now we can use our `MXNetModel` object to build and deploy an `MXNetPredictor`. This creates a SageMaker Model and Endpoint, the latter of which we can use for performing inference. \n",
    "\n",
    "We pass the following arguments to the deploy method:\n",
    "\n",
    "* `initial_instance_count` - how many instances to back the endpoint.\n",
    "* `instance_type` - which EC2 instance type to use for the endpoint. Supported instance types are listed at https://aws.amazon.com/sagemaker/pricing/instance-types/\n",
    "* `accelerator_type` - which EI accelerator type to attach to each of our instances. The supported accelerator types are listed on the same page.\n",
    "\n",
    "### How our models are loaded\n",
    "The predefined SageMaker MXNet containers provide a default `model_fn`, which determines how your model is loaded. The default `model_fn` loads an MXNet Module object with a context based on the instance type of the endpoint.\n",
    "\n",
    "This applies to EI as well. If an EI accelerator is attached to your endpoint and a custom `model_fn` isn't provided, then the default `model_fn` will load the MXNet Module object with an EI context, `mx.eia()`. The default `model_fn` works with the default save function. If a custom save function was defined, then you may need to write a custom `model_fn`. For more information on `model_fn`, see the [SageMaker documentation](https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/mxnet/README.rst#model-loading).\n",
    "\n",
    "### Choosing instance types\n",
    "We will deploy this model with instance type `ml.m5.xlarge` and accelerator type `ml.eia1.medium`. You can experiment with other instance types and EI sizes based on your model requirements. This model requires a relatively large amount of CPU memory, so we chose an M5 instance, which offers more memory than C5 instances and is therefore more cost-effective here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "predictor = mxnet_model.deploy(initial_instance_count=1, instance_type='ml.m5.xlarge', accelerator_type='ml.eia1.medium')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performing inference\n",
    "\n",
    "With our Endpoint deployed, we can now send inference requests to it. We'll use one image as an example here."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparing the image\n",
    "\n",
    "First, we'll download the image (and view it)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "# Download image for inference\n",
    "img_path = mx.test_utils.download('https://s3.amazonaws.com/onnx-mxnet/examples/mallard_duck.jpg')\n",
    "img = mx.image.imread(img_path)\n",
    "plt.imshow(img.asnumpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we will preprocess the inference image. We will resize it to 256x256, take a center crop of 224x224, normalize it, and add a batch dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
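    "from mxnet.gluon.data.vision import transforms\n",
    "\n",
    "def preprocess(img):\n",
    "    # Resize, center-crop, convert to a CHW float tensor in [0, 1], then\n",
    "    # normalize each channel with the ImageNet mean and standard deviation.\n",
    "    transform_fn = transforms.Compose([\n",
    "        transforms.Resize(256),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "    ])\n",
    "    img = transform_fn(img)\n",
    "    # Add a leading batch dimension: (3, 224, 224) -> (1, 3, 224, 224)\n",
    "    img = img.expand_dims(axis=0)\n",
    "    return img\n",
    "\n",
    "input_image = preprocess(img)"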
   ]
  },
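  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the arithmetic inside `preprocess` concrete, here is a NumPy-only sketch of the crop, scale, and normalize steps. It is not the Gluon implementation: the helper names `center_crop` and `to_normalized_batch` are hypothetical, the input is a random array standing in for an image that has already been resized to 256x256, and the resize step itself is omitted.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Per-channel mean and standard deviation used in the cell above.\n",
    "MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)\n",
    "STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)\n",
    "\n",
    "def center_crop(img, size):\n",
    "    # img has shape (height, width, channels); keep the central size x size window.\n",
    "    h, w = img.shape[:2]\n",
    "    top = (h - size) // 2\n",
    "    left = (w - size) // 2\n",
    "    return img[top:top + size, left:left + size]\n",
    "\n",
    "def to_normalized_batch(img):\n",
    "    # Scale uint8 pixels to [0, 1] and move channels first (HWC -> CHW).\n",
    "    chw = img.astype(np.float32).transpose(2, 0, 1) / 255.0\n",
    "    # Normalize each channel, broadcasting the statistics over height and width.\n",
    "    chw = (chw - MEAN[:, None, None]) / STD[:, None, None]\n",
    "    # Add the leading batch dimension.\n",
    "    return chw[np.newaxis, ...]\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "fake_image = rng.integers(0, 256, size=(256, 256, 3), dtype=np.uint8)\n",
    "batch = to_normalized_batch(center_crop(fake_image, 224))\n",
    "print(batch.shape)\n",
    "```\n",
    "\n",
    "The resulting array has the same layout as `input_image`: one batch entry, three channels, and a 224x224 spatial grid."
   ]
  },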
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sending the inference request\n",
    "\n",
    "Now we can use the predictor object to classify the input image:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "scores = predictor.predict(input_image.asnumpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the inference result, let's download and load the `synset.txt` file, which contains the class labels for ImageNet. We then print the top 5 classes in descending order of probability:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/synset.txt')\n",
    "with open('synset.txt', 'r') as f:\n",
    "    labels = [line.rstrip() for line in f]\n",
    "\n",
    "# Class indices sorted from highest to lowest score\n",
    "top_indices = np.argsort(scores)[::-1]\n",
    "\n",
    "for i in top_indices[:5]:\n",
    "    print('class=%s ; probability=%f' % (labels[i], scores[i]))"
   ]
  },
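  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ranking step can also be tried in isolation, without an endpoint. The sketch below uses made-up labels and scores, and the helpers `softmax` and `top_k` are hypothetical names introduced here for illustration. The softmax is included only as an assumption about how raw model outputs could be turned into probabilities; the cell above simply prints whatever values the endpoint returned.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def softmax(logits):\n",
    "    # Subtract the maximum first for numerical stability.\n",
    "    shifted = np.asarray(logits, dtype=np.float64) - np.max(logits)\n",
    "    exp = np.exp(shifted)\n",
    "    return exp / exp.sum()\n",
    "\n",
    "def top_k(scores, labels, k=5):\n",
    "    # argsort is ascending, so reverse it to get the highest scores first.\n",
    "    order = np.argsort(scores)[::-1][:k]\n",
    "    return [(labels[i], float(scores[i])) for i in order]\n",
    "\n",
    "# Toy data, not real model output.\n",
    "toy_labels = ['duck', 'goose', 'swan', 'heron']\n",
    "toy_probs = softmax([2.0, 1.0, 0.5, -1.0])\n",
    "\n",
    "for label, prob in top_k(toy_probs, toy_labels, k=2):\n",
    "    print('class=%s ; probability=%f' % (label, prob))\n",
    "```\n",
    "\n",
    "Because `np.argsort` sorts in ascending order, reversing the result before slicing is what puts the most likely classes first."
   ]
  },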
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Deleting the Endpoint\n",
    "\n",
    "Since we've reached the end, we'll delete the SageMaker Endpoint to release the instance associated with it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "predictor.delete_endpoint()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_mxnet_p36",
   "language": "python",
   "name": "conda_mxnet_p36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "notice": "Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.  Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
